package service

import (
	"context"
	"testing"

	"github.com/pkg/errors"
	"github.com/stackrox/rox/central/convert/storagetov2"
	"github.com/stackrox/rox/central/convert/testutils"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/common"
	dsMock "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/datastore/mocks"
	mgrMock "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/manager/requestmgr/mocks"
	v2 "github.com/stackrox/rox/generated/api/v2"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/protoassert"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stretchr/testify/suite"
	"go.uber.org/mock/gomock"
)

// TestVulnRequestService is the `go test` entry point; it hands a fresh
// VulnRequestServiceTestSuite to testify's suite runner.
func TestVulnRequestService(t *testing.T) {
	suite.Run(t, &VulnRequestServiceTestSuite{})
}

// VulnRequestServiceTestSuite exercises the vulnerability-request Service
// against gomock doubles of its datastore and request manager.
type VulnRequestServiceTestSuite struct {
	suite.Suite

	// mockCtrl owns the mocks below; SetupTest recreates it.
	mockCtrl *gomock.Controller

	// Mocked dependencies that SetupTest injects into service.
	datastore *dsMock.MockDataStore
	manager   *mgrMock.MockManager

	// service is the implementation under test, built by New in SetupTest.
	service Service
}

// SetupTest builds a new gomock controller, new datastore and manager mocks,
// and a Service wired to those mocks, so no expectations leak between tests.
func (s *VulnRequestServiceTestSuite) SetupTest() {
	ctrl := gomock.NewController(s.T())
	s.mockCtrl = ctrl

	s.datastore = dsMock.NewMockDataStore(ctrl)
	s.manager = mgrMock.NewMockManager(ctrl)

	s.service = New(s.datastore, s.manager)
}

// TestGetVulnerabilityException checks that GetVulnerabilityException returns
// the v2 exception matching the stored request, and that a datastore error is
// surfaced to the caller together with a nil response.
func (s *VulnRequestServiceTestSuite) TestGetVulnerabilityException() {
	ctx := sac.WithAllAccess(context.Background())

	// Happy path: the datastore finds the request.
	s.datastore.EXPECT().Get(ctx, "id").Return(testutils.GetTestVulnDeferralRequestFull(s.T()), true, nil)
	resp, err := s.service.GetVulnerabilityException(ctx, &v2.ResourceByID{Id: "id"})
	// Fail fast: comparing against a nil response would only add a misleading second failure.
	s.Require().NoError(err)
	protoassert.Equal(s.T(), testutils.GetTestVulnDeferralExceptionFull(s.T()), resp.GetException())

	// Failure path: a datastore error must propagate and no response is returned.
	s.datastore.EXPECT().Get(ctx, "id").Return(nil, false, errors.New("fake"))
	resp, err = s.service.GetVulnerabilityException(ctx, &v2.ResourceByID{Id: "id"})
	s.Error(err)
	s.Nil(resp)
}

// TestListVulnerabilityExceptions checks that every raw request returned by the
// datastore search shows up, converted, in the listed exceptions.
func (s *VulnRequestServiceTestSuite) TestListVulnerabilityExceptions() {
	ctx := sac.WithAllAccess(context.Background())
	s.datastore.EXPECT().SearchRawRequests(ctx, gomock.Any()).Return([]*storage.VulnerabilityRequest{testutils.GetTestVulnDeferralRequestFull(s.T())}, nil)

	resp, err := s.service.ListVulnerabilityExceptions(ctx, &v2.RawQuery{})
	// Fail fast so an error is not followed by a confusing element-mismatch report.
	s.Require().NoError(err)
	protoassert.ElementsMatch(s.T(), []*v2.VulnerabilityException{testutils.GetTestVulnDeferralExceptionFull(s.T())}, resp.GetExceptions())
}

// TestCreateDeferVulnerabilityException checks that a deferral request is
// handed to the manager's Create and that the returned exception is PENDING
// and targets the DEFERRED vulnerability state.
func (s *VulnRequestServiceTestSuite) TestCreateDeferVulnerabilityException() {
	ctx := sac.WithAllAccess(context.Background())
	req := &v2.CreateDeferVulnerabilityExceptionRequest{
		Cves:    []string{"cve1"},
		Comment: "message",
		Scope: &v2.VulnerabilityException_Scope{
			ImageScope: &v2.VulnerabilityException_Scope_Image{
				Registry: "reg",
				Remote:   "remote",
				Tag:      "tag",
			},
		},
		ExceptionExpiry: &v2.ExceptionExpiry{
			ExpiryType: v2.ExceptionExpiry_TIME,
		},
	}

	s.manager.EXPECT().Create(ctx, gomock.Any()).Return(nil)
	resp, err := s.service.CreateDeferVulnerabilityException(ctx, req)
	// These guards must be fatal: the checks below are meaningless without an exception.
	s.Require().NoError(err)
	exception := resp.GetException()
	s.Require().NotNil(exception)
	s.Equal(v2.ExceptionStatus_PENDING, exception.GetStatus())
	s.Equal(v2.VulnerabilityState_DEFERRED, exception.GetTargetState())
}

// TestCreateFalsePositiveVulnerabilityException checks that a false-positive
// request is handed to the manager's Create and that the returned exception is
// PENDING and targets the FALSE_POSITIVE vulnerability state.
func (s *VulnRequestServiceTestSuite) TestCreateFalsePositiveVulnerabilityException() {
	ctx := sac.WithAllAccess(context.Background())
	req := &v2.CreateFalsePositiveVulnerabilityExceptionRequest{
		Cves:    []string{"cve1"},
		Comment: "message",
		Scope: &v2.VulnerabilityException_Scope{
			ImageScope: &v2.VulnerabilityException_Scope_Image{
				Registry: "reg",
				Remote:   "remote",
				Tag:      "tag",
			},
		},
	}

	s.manager.EXPECT().Create(ctx, gomock.Any()).Return(nil)
	resp, err := s.service.CreateFalsePositiveVulnerabilityException(ctx, req)
	// These guards must be fatal: the checks below are meaningless without an exception.
	s.Require().NoError(err)
	exception := resp.GetException()
	s.Require().NotNil(exception)
	s.Equal(v2.ExceptionStatus_PENDING, exception.GetStatus())
	s.Equal(v2.VulnerabilityState_FALSE_POSITIVE, exception.GetTargetState())
}

// TestApproveVulnerabilityException checks that approval forwards the request
// ID and comment to the manager's Approve, and that the response carries the
// v2 conversion of the request the manager returned.
func (s *VulnRequestServiceTestSuite) TestApproveVulnerabilityException() {
	ctx := sac.WithAllAccess(context.Background())
	pendingReq := testutils.GetTestVulnDeferralRequestFull(s.T())
	req := &v2.ApproveVulnerabilityExceptionRequest{
		Id:      pendingReq.GetId(),
		Comment: "approved",
	}
	// Clone so the manager's "approved" result differs from the pending fixture.
	approvedReq := pendingReq.CloneVT()
	approvedReq.Status = storage.RequestStatus_APPROVED

	s.manager.EXPECT().Approve(ctx, pendingReq.GetId(), &common.VulnRequestParams{Comment: req.GetComment()}).
		Return(approvedReq, nil)

	resp, err := s.service.ApproveVulnerabilityException(ctx, req)
	s.Require().NoError(err)
	protoassert.Equal(s.T(), storagetov2.VulnerabilityException(approvedReq), resp.GetException())
}

// TestDenyVulnerabilityException checks that denial forwards the request ID and
// comment to the manager's Deny, and that the response carries the v2
// conversion of the request the manager returned.
func (s *VulnRequestServiceTestSuite) TestDenyVulnerabilityException() {
	ctx := sac.WithAllAccess(context.Background())
	pendingReq := testutils.GetTestVulnFPRequestFull(s.T())
	req := &v2.DenyVulnerabilityExceptionRequest{
		Id:      pendingReq.GetId(),
		Comment: "denied",
	}
	// Clone so the manager's "denied" result differs from the pending fixture.
	deniedReq := pendingReq.CloneVT()
	deniedReq.Status = storage.RequestStatus_DENIED

	s.manager.EXPECT().Deny(ctx, pendingReq.GetId(), &common.VulnRequestParams{Comment: req.GetComment()}).
		Return(deniedReq, nil)

	resp, err := s.service.DenyVulnerabilityException(ctx, req)
	s.Require().NoError(err)
	protoassert.Equal(s.T(), storagetov2.VulnerabilityException(deniedReq), resp.GetException())
}
